{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d2afa3e9",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Evaluation utilities for the fine-tuned open-source model (Week 7)\n",
        "import re\n",
        "import math\n",
        "import numpy as np\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Extract numeric price from model output\n",
        "def extract_price(text: str) -> float:\n",
        "    text = (text or \"\").replace(\"$\", \"\").replace(\",\", \"\")\n",
        "    m = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", text)\n",
        "    return float(m.group(0)) if m else 0.0\n",
        "\n",
        "# Build prompt consistent with Week 7 training template\n",
        "def build_pricing_prompt(item) -> str:\n",
        "    # Matches the training format used in Week 7\n",
        "    return (\n",
        "        \"<|system|>\\nYou are a retail price estimator. Predict the most likely new retail price in USD.\\n\"\n",
        "        \"<|user|>\\n\"\n",
        "        f\"{item.title}\\n{item.description}\\n\"\n",
        "        \"<|assistant|>\\n\"\n",
        "    )\n",
        "\n",
        "# Single-item prediction using the fine-tuned causal LM\n",
        "@torch.no_grad()\n",
        "def predict_price(model, tokenizer, item, max_new_tokens: int = 20) -> float:\n",
        "    prompt = build_pricing_prompt(item)\n",
        "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
        "    outputs = model.generate(\n",
        "        **inputs,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        temperature=0.7,\n",
        "        do_sample=True,\n",
        "        pad_token_id=tokenizer.eos_token_id,\n",
        "    )\n",
        "    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "    # Take only the newly generated continuation beyond the prompt\n",
        "    continuation = decoded[len(tokenizer.decode(inputs[\"input_ids\"][0], skip_special_tokens=True)) :]\n",
        "    return extract_price(continuation)\n",
        "\n",
        "# Batch evaluation (MAE, RMSE, MAPE) with quick scatter plot\n",
        "def evaluate_model(model, tokenizer, test_items, limit: int = None, title: str = \"Fine-tuned Model Evaluation\"):\n",
        "    if not test_items:\n",
        "        print(\"⚠️ No test items available.\")\n",
        "        return {\"mae\": None, \"rmse\": None, \"mape\": None}\n",
        "\n",
        "    items = test_items[:limit] if limit else test_items\n",
        "\n",
        "    y_true, y_pred = [], []\n",
        "    for i, item in enumerate(items):\n",
        "        try:\n",
        "            pred = predict_price(model, tokenizer, item)\n",
        "        except Exception as e:\n",
        "            print(f\"Error on item {i}: {e}\")\n",
        "            pred = 0.0\n",
        "        y_true.append(float(getattr(item, \"price\", 0.0)))\n",
        "        y_pred.append(float(pred))\n",
        "\n",
        "    y_true_np = np.array(y_true, dtype=float)\n",
        "    y_pred_np = np.array(y_pred, dtype=float)\n",
        "\n",
        "    mae = float(np.mean(np.abs(y_pred_np - y_true_np)))\n",
        "    rmse = float(np.sqrt(np.mean((y_pred_np - y_true_np) ** 2)))\n",
        "    with np.errstate(divide='ignore', invalid='ignore'):\n",
        "        mape_arr = np.where(y_true_np != 0, np.abs((y_pred_np - y_true_np) / y_true_np), np.nan)\n",
        "    mape = float(np.nanmean(mape_arr)) * 100.0\n",
        "\n",
        "    print(f\"\\n📈 {title}\")\n",
        "    print(f\"MAE : {mae:.2f}\")\n",
        "    print(f\"RMSE: {rmse:.2f}\")\n",
        "    print(f\"MAPE: {mape:.2f}%\")\n",
        "\n",
        "    # Scatter plot\n",
        "    try:\n",
        "        plt.figure(figsize=(6, 6))\n",
        "        plt.scatter(y_true_np, y_pred_np, alpha=0.6)\n",
        "        mx = max(y_true_np.max() if y_true_np.size else 0, y_pred_np.max() if y_pred_np.size else 0)\n",
        "        plt.plot([0, mx], [0, mx], 'r--', label='Ideal')\n",
        "        plt.xlabel('Actual Price')\n",
        "        plt.ylabel('Predicted Price')\n",
        "        plt.title(title)\n",
        "        plt.legend()\n",
        "        plt.tight_layout()\n",
        "        plt.show()\n",
        "    except Exception as e:\n",
        "        print(f\"Plotting error: {e}\")\n",
        "\n",
        "    return {\"mae\": mae, \"rmse\": rmse, \"mape\": mape}\n",
        "\n",
        "# Convenience wrapper mirroring Week 6's Tester usage pattern\n",
        "# Usage:\n",
        "#   results = evaluate_model(model, tokenizer, test, limit=len(test))\n",
        "print(\"✅ Evaluation utilities for Week 7 added. Use evaluate_model(model, tokenizer, test, limit=len(test)).\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c88d0ea8",
      "metadata": {
        "id": "c88d0ea8"
      },
      "source": [
        "# Week 7 - Complete Fine-tuning with Open Source LLMs\n",
        "\n",
        "This notebook implements QLoRA fine-tuning of open-source LLMs for product price prediction.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "id": "721835a5",
      "metadata": {
        "id": "721835a5"
      },
      "outputs": [],
      "source": [
        "%pip install -q -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n",
        "%pip install -q -U transformers>=4.45.0 accelerate>=0.33.0 peft>=0.11.1 trl>=0.8.0\n",
        "%pip install -q -U datasets \"huggingface_hub>=0.23.2,<1.0\" sentencepiece einops safetensors\n",
        "%pip install -q -U bitsandbytes>=0.43.2 xformers\n",
        "%pip install -q -U wandb tensorboard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "id": "8a8017b0",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8a8017b0",
        "outputId": "6c5288b6-3d15-4439-de01-ad2ff7b2b262"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "PyTorch version: 2.8.0+cu126\n",
            "CUDA available: True\n",
            "GPU: NVIDIA A100-SXM4-40GB\n",
            "GPU Memory: 42.5 GB\n",
            "CUDA version: 12.6\n"
          ]
        }
      ],
      "source": [
        "# Core imports\n",
        "import os\n",
        "import torch\n",
        "import pickle\n",
        "import numpy as np\n",
        "import json\n",
        "import re\n",
        "from datetime import datetime\n",
        "from datasets import Dataset\n",
        "from transformers import (\n",
        "    AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,\n",
        "    TrainingArguments, Trainer, DataCollatorForLanguageModeling\n",
        ")\n",
        "from peft import LoraConfig, TaskType, get_peft_model, PeftModel\n",
        "from trl import SFTTrainer\n",
        "import transformers\n",
        "import wandb\n",
        "\n",
        "# Enable optimizations for Colab Pro\n",
        "torch.backends.cudnn.benchmark = True\n",
        "torch.backends.cuda.matmul.allow_tf32 = True\n",
        "torch.backends.cudnn.allow_tf32 = True\n",
        "\n",
        "print(f\"PyTorch version: {torch.__version__}\")\n",
        "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
        "if torch.cuda.is_available():\n",
        "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
        "    print(f\"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
        "    print(f\"CUDA version: {torch.version.cuda}\")\n",
        "else:\n",
        "    raise SystemExit(\"❌ No GPU detected.\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "id": "0b4d0cd3",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 280
        },
        "id": "0b4d0cd3",
        "outputId": "65ab54e5-4fec-4db8-e6e9-3fb86d2a13f3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "✅ Using Colab secrets\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.\n",
            "WARNING:huggingface_hub._login:Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "Finishing previous runs because reinit is set to 'default'."
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              " View run <strong style=\"color:#cdcd00\">wobbly-resonance-1</strong> at: <a href='https://wandb.ai/oluoch-joshua-udemy/colab-pro-finetuning/runs/fwkqveds' target=\"_blank\">https://wandb.ai/oluoch-joshua-udemy/colab-pro-finetuning/runs/fwkqveds</a><br> View project at: <a href='https://wandb.ai/oluoch-joshua-udemy/colab-pro-finetuning' target=\"_blank\">https://wandb.ai/oluoch-joshua-udemy/colab-pro-finetuning</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "Find logs at: <code>./wandb/run-20251028_115212-fwkqveds/logs</code>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "Tracking run with wandb version 0.22.2"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "Run data is saved locally in <code>/content/wandb/run-20251028_115650-rd1q63l3</code>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "Syncing run <strong><a href='https://wandb.ai/oluoch-joshua-udemy/colab-pro-finetuning/runs/rd1q63l3' target=\"_blank\">easy-cloud-2</a></strong> to <a href='https://wandb.ai/oluoch-joshua-udemy/colab-pro-finetuning' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              " View project at <a href='https://wandb.ai/oluoch-joshua-udemy/colab-pro-finetuning' target=\"_blank\">https://wandb.ai/oluoch-joshua-udemy/colab-pro-finetuning</a>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              " View run at <a href='https://wandb.ai/oluoch-joshua-udemy/colab-pro-finetuning/runs/rd1q63l3' target=\"_blank\">https://wandb.ai/oluoch-joshua-udemy/colab-pro-finetuning/runs/rd1q63l3</a>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "✅ W&B initialized\n"
          ]
        }
      ],
      "source": [
        "# Environment setup for Colab Pro\n",
        "try:\n",
        "    from google.colab import userdata\n",
        "    os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n",
        "    os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')\n",
        "    print(\"✅ Using Colab secrets\")\n",
        "except:\n",
        "    from dotenv import load_dotenv\n",
        "    load_dotenv(override=True)\n",
        "    os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-hf-token')\n",
        "    os.environ['WANDB_API_KEY'] = os.getenv('WANDB_API_KEY', 'your-wandb-key')\n",
        "    print(\"✅ Using local environment\")\n",
        "\n",
        "# Login to HuggingFace\n",
        "from huggingface_hub import login\n",
        "login(os.environ['HF_TOKEN'])\n",
        "\n",
        "# Initialize Weights & Biases (optional)\n",
        "try:\n",
        "    wandb.init(project=\"colab-pro-finetuning\", mode=\"online\")\n",
        "    print(\"✅ W&B initialized\")\n",
        "except:\n",
        "    print(\"⚠️  W&B not available, continuing without logging\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "id": "809d2271",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "809d2271",
        "outputId": "2afd08ed-5da7-4a93-99cd-4d8f881bd0af"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "📦 Loading pre-processed pickle files...\n",
            "✅ Loaded training data: train.pkl (150 items)\n",
            "✅ Loaded test data: test.pkl (50 items)\n",
            "✅ Loaded validation data: validation.pkl (50 items)\n",
            "\n",
            "📊 Dataset Statistics:\n",
            "   Training: 150 items\n",
            "   Test: 50 items\n",
            "   Validation: 50 items\n"
          ]
        }
      ],
      "source": [
        "# Load pre-processed pickle files (optimized for Colab Pro)\n",
        "def load_pickle_data():\n",
        "    \"\"\"Load pre-processed pickle files with robust error handling\"\"\"\n",
        "    print(\"📦 Loading pre-processed pickle files...\")\n",
        "\n",
        "    # Try multiple locations for pickle files\n",
        "    pickle_files = [\n",
        "        'train.pkl', 'test.pkl', 'validation.pkl'\n",
        "    ]\n",
        "\n",
        "    train = None\n",
        "    test = None\n",
        "    validation = None\n",
        "\n",
        "    # Load training data\n",
        "    for file_path in ['train.pkl']:\n",
        "        if os.path.exists(file_path):\n",
        "            try:\n",
        "                with open(file_path, 'rb') as f:\n",
        "                    train = pickle.load(f)\n",
        "                print(f\"✅ Loaded training data: {file_path} ({len(train)} items)\")\n",
        "                break\n",
        "            except Exception as e:\n",
        "                print(f\"❌ Error loading {file_path}: {e}\")\n",
        "\n",
        "    # Load test data\n",
        "    for file_path in ['test.pkl']:\n",
        "        if os.path.exists(file_path):\n",
        "            try:\n",
        "                with open(file_path, 'rb') as f:\n",
        "                    test = pickle.load(f)\n",
        "                print(f\"✅ Loaded test data: {file_path} ({len(test)} items)\")\n",
        "                break\n",
        "            except Exception as e:\n",
        "                print(f\"❌ Error loading {file_path}: {e}\")\n",
        "\n",
        "    # Load validation data\n",
        "    for file_path in ['validation.pkl']:\n",
        "        if os.path.exists(file_path):\n",
        "            try:\n",
        "                with open(file_path, 'rb') as f:\n",
        "                    validation = pickle.load(f)\n",
        "                print(f\"✅ Loaded validation data: {file_path} ({len(validation)} items)\")\n",
        "                break\n",
        "            except Exception as e:\n",
        "                print(f\"❌ Error loading {file_path}: {e}\")\n",
        "\n",
        "    # If no pickle files found, create sample data\n",
        "    if not train or not test or not validation:\n",
        "        print(\"🔄 No pickle files found, creating sample data...\")\n",
        "        train, test, validation = create_sample_data()\n",
        "\n",
        "    return train, test, validation\n",
        "\n",
        "def create_sample_data():\n",
        "    \"\"\"Create sample data for demonstration\"\"\"\n",
        "    # Sample product data (expanded for better training)\n",
        "    sample_products = [\n",
        "        {\"title\": \"Wireless Bluetooth Headphones\", \"price\": 89.99, \"category\": \"Electronics\"},\n",
        "        {\"title\": \"Stainless Steel Water Bottle\", \"price\": 24.99, \"category\": \"Home & Kitchen\"},\n",
        "        {\"title\": \"Organic Cotton T-Shirt\", \"price\": 19.99, \"category\": \"Clothing\"},\n",
        "        {\"title\": \"Ceramic Coffee Mug\", \"price\": 12.99, \"category\": \"Home & Kitchen\"},\n",
        "        {\"title\": \"LED Desk Lamp\", \"price\": 45.99, \"category\": \"Electronics\"},\n",
        "        {\"title\": \"Yoga Mat\", \"price\": 29.99, \"category\": \"Sports & Outdoors\"},\n",
        "        {\"title\": \"Leather Wallet\", \"price\": 39.99, \"category\": \"Accessories\"},\n",
        "        {\"title\": \"Bluetooth Speaker\", \"price\": 79.99, \"category\": \"Electronics\"},\n",
        "        {\"title\": \"Kitchen Knife Set\", \"price\": 129.99, \"category\": \"Home & Kitchen\"},\n",
        "        {\"title\": \"Running Shoes\", \"price\": 89.99, \"category\": \"Sports & Outdoors\"},\n",
        "        {\"title\": \"Smartphone Case\", \"price\": 15.99, \"category\": \"Electronics\"},\n",
        "        {\"title\": \"Coffee Maker\", \"price\": 89.99, \"category\": \"Home & Kitchen\"},\n",
        "        {\"title\": \"Backpack\", \"price\": 49.99, \"category\": \"Accessories\"},\n",
        "        {\"title\": \"Tennis Racket\", \"price\": 79.99, \"category\": \"Sports & Outdoors\"},\n",
        "        {\"title\": \"Laptop Stand\", \"price\": 34.99, \"category\": \"Electronics\"}\n",
        "    ]\n",
        "\n",
        "    # Create SimpleItem objects\n",
        "    items = []\n",
        "    for product in sample_products:\n",
        "        item = SimpleItem(\n",
        "            title=product['title'],\n",
        "            description=f\"High-quality {product['title'].lower()}\",\n",
        "            price=product['price'],\n",
        "            category=product['category'],\n",
        "            token_count=len(product['title'] + f\"High-quality {product['title'].lower()}\") // 4\n",
        "        )\n",
        "        items.append(item)\n",
        "\n",
        "    # Split into train/test/validation\n",
        "    train = items[:10]  # 10 items\n",
        "    test = items[10:13]  # 3 items\n",
        "    validation = items[13:]  # 2 items\n",
        "\n",
        "    print(f\"✅ Created sample data: {len(train)} train, {len(test)} test, {len(validation)} validation\")\n",
        "    return train, test, validation\n",
        "\n",
        "# SimpleItem class definition for pickle compatibility\n",
        "class SimpleItem:\n",
        "    \"\"\"Simple item class for pickle compatibility\"\"\"\n",
        "    def __init__(self, title, description, price, category=\"Human_Generated\", token_count=0):\n",
        "        self.title = title\n",
        "        self.description = description\n",
        "        self.price = price\n",
        "        self.category = category\n",
        "        self.token_count = token_count\n",
        "\n",
        "    def test_prompt(self):\n",
        "        \"\"\"Return a prompt suitable for testing\"\"\"\n",
        "        return f\"How much does this cost to the nearest dollar?\\n\\n{self.title}\\n\\n{self.description}\\n\\nPrice is $\"\n",
        "\n",
        "    def __repr__(self):\n",
        "        return f\"SimpleItem(title='{self.title[:50]}...', price=${self.price})\"\n",
        "\n",
        "# Load the data\n",
        "train, test, validation = load_pickle_data()\n",
        "\n",
        "print(f\"\\n📊 Dataset Statistics:\")\n",
        "print(f\"   Training: {len(train)} items\")\n",
        "print(f\"   Test: {len(test)} items\")\n",
        "print(f\"   Validation: {len(validation)} items\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "id": "946a3a05",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "946a3a05",
        "outputId": "41936ca5-d092-43a2-ed29-d66607af7d89"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "✅ Datasets prepared:\n",
            "   Training: 150 examples\n",
            "   Validation: 50 examples\n",
            "   Sample training text: <|system|>\n",
            "You are a retail price estimator. Predict the most likely new retail price in USD.\n",
            "<|user...\n"
          ]
        }
      ],
      "source": [
        "# Prepare datasets for training (optimized for Colab Pro)\n",
        "def prepare_training_data(items):\n",
        "    \"\"\"Convert items to training format\"\"\"\n",
        "    data = []\n",
        "    for item in items:\n",
        "        # Create training prompt\n",
        "        prompt = f\"<|system|>\\nYou are a retail price estimator. Predict the most likely new retail price in USD.\\n<|user|>\\n{item.title}\\n{item.description}\\n<|assistant|>\\n${item.price:.2f}\"\n",
        "        data.append({\"text\": prompt})\n",
        "    return data\n",
        "\n",
        "# Prepare training and validation datasets\n",
        "train_data = prepare_training_data(train)\n",
        "val_data = prepare_training_data(validation)\n",
        "\n",
        "# Convert to HuggingFace datasets\n",
        "train_ds = Dataset.from_list(train_data)\n",
        "val_ds = Dataset.from_list(val_data)\n",
        "\n",
        "print(f\"✅ Datasets prepared:\")\n",
        "print(f\"   Training: {len(train_ds)} examples\")\n",
        "print(f\"   Validation: {len(val_ds)} examples\")\n",
        "print(f\"   Sample training text: {train_ds[0]['text'][:100]}...\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 40,
      "id": "zWgL4fhku_XN",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 289,
          "referenced_widgets": [
            "245570c62c3844728d7125a706fbbc9b",
            "8d95da3803e542f8b855175013d497ba",
            "34a89db126a64690bc2f6c8656ba2210",
            "4c47ce21b5a14328aa22403782e4da9b",
            "7dff366d9e71427dbae40b1dce7a9bfa",
            "0614c35b3690494ca3b8f9ab71d71a08",
            "42315e83fbac49c2bc7f2faf1abcc22e",
            "4cdff5bdf7574795802e821aa42f3c4e",
            "754aa440f45c4a878d99572368d659c8",
            "8b5f0c156a9641cfa5413668a0b97b9c",
            "aff498bd632f4036958f59cfc6587ea3",
            "d6825cc926a24f2482ce72c15242081e",
            "24b2b5f5d92049a79014b8278e97451b",
            "a6001d34e58a47cab0d8bff2451afb6e",
            "b42e8d8b61d7431a814a03c5e07a1166",
            "9ce1659c776140bcaf3c16eae6f70967",
            "cceae79c145d4b73a64e80ad3fc8866c",
            "56bc56071ff04223935dc2d98d2703ab",
            "67cacc87afe14250baaa073289fb4a8f",
            "227eea7074544adbb2c34b9dde340fa5",
            "8bd9aebb2cc5420094b2b441a5183523",
            "1c3eb3793b6e4291b4fa57ce8419ef1f"
          ]
        },
        "id": "zWgL4fhku_XN",
        "outputId": "8d375a13-59cf-4eea-f16f-fde5dbf7f0e8"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "🔄 Checking dataset status...\n",
            "Training dataset columns: ['input_ids', 'attention_mask', 'labels']\n",
            "Validation dataset columns: ['input_ids', 'attention_mask', 'labels']\n",
            "✅ Datasets already tokenized\n",
            "🔄 Ensuring consistent sequence lengths...\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "245570c62c3844728d7125a706fbbc9b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/150 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d6825cc926a24f2482ce72c15242081e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/50 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "🔍 Verifying sequence lengths...\n",
            "Training sequence lengths - Min: 256, Max: 256\n",
            "Validation sequence lengths - Min: 256, Max: 256\n",
            "✅ All sequences have consistent length\n",
            "Sample input_ids shape: torch.Size([256])\n",
            "Sample attention_mask shape: torch.Size([256])\n",
            "Sample labels shape: torch.Size([256])\n"
          ]
        }
      ],
      "source": [
        "# Tokenize datasets for causal LM (creates input_ids, attention_mask, labels)\n",
        "MAX_LEN = 256  # Further reduced for stability\n",
        "\n",
        "def tokenize_function(examples):\n",
        "    # Tokenize with padding and truncation\n",
        "    outputs = tokenizer(\n",
        "        examples[\"text\"],\n",
        "        truncation=True,\n",
        "        max_length=MAX_LEN,\n",
        "        padding=\"max_length\",  # Pad to max_length\n",
        "        return_tensors=None,   # Return lists, not tensors\n",
        "    )\n",
        "    # Labels are the shifted inputs for causal LM\n",
        "    outputs[\"labels\"] = outputs[\"input_ids\"].copy()\n",
        "    return outputs\n",
        "\n",
        "def ensure_consistent_lengths(dataset, max_len):\n",
        "    \"\"\"Ensure all sequences in dataset have consistent length\"\"\"\n",
        "    def pad_sequences(examples):\n",
        "        # Convert to lists if they're tensors\n",
        "        input_ids = []\n",
        "        attention_masks = []\n",
        "        labels = []\n",
        "\n",
        "        for i in range(len(examples[\"input_ids\"])):\n",
        "            # Get the sequence and convert to list if tensor\n",
        "            seq = examples[\"input_ids\"][i]\n",
        "            attn = examples[\"attention_mask\"][i]\n",
        "            lbl = examples[\"labels\"][i]\n",
        "\n",
        "            # Convert tensors to lists\n",
        "            if hasattr(seq, 'tolist'):\n",
        "                seq = seq.tolist()\n",
        "            if hasattr(attn, 'tolist'):\n",
        "                attn = attn.tolist()\n",
        "            if hasattr(lbl, 'tolist'):\n",
        "                lbl = lbl.tolist()\n",
        "\n",
        "            # Truncate if too long\n",
        "            if len(seq) > max_len:\n",
        "                seq = seq[:max_len]\n",
        "                attn = attn[:max_len]\n",
        "                lbl = lbl[:max_len]\n",
        "\n",
        "            # Pad if too short\n",
        "            while len(seq) < max_len:\n",
        "                seq.append(tokenizer.pad_token_id)\n",
        "                attn.append(0)  # 0 for padding\n",
        "                lbl.append(-100)  # -100 for padding in labels (ignored in loss)\n",
        "\n",
        "            input_ids.append(seq)\n",
        "            attention_masks.append(attn)\n",
        "            labels.append(lbl)\n",
        "\n",
        "        return {\n",
        "            \"input_ids\": input_ids,\n",
        "            \"attention_mask\": attention_masks,\n",
        "            \"labels\": labels\n",
        "        }\n",
        "\n",
        "    return dataset.map(pad_sequences, batched=True)\n",
        "\n",
        "print(\"🔄 Checking dataset status...\")\n",
        "print(f\"Training dataset columns: {train_ds.column_names}\")\n",
        "print(f\"Validation dataset columns: {val_ds.column_names}\")\n",
        "\n",
        "# Check if we need to tokenize or just ensure consistent lengths\n",
        "if \"text\" in train_ds.column_names:\n",
        "    print(\"🔄 Tokenizing datasets...\")\n",
        "    train_ds = train_ds.map(tokenize_function, batched=True, remove_columns=[\"text\"])\n",
        "    val_ds = val_ds.map(tokenize_function, batched=True, remove_columns=[\"text\"])\n",
        "    print(\"✅ Tokenization complete\")\n",
        "else:\n",
        "    print(\"✅ Datasets already tokenized\")\n",
        "\n",
        "# Ensure consistent lengths\n",
        "print(\"🔄 Ensuring consistent sequence lengths...\")\n",
        "train_ds = ensure_consistent_lengths(train_ds, MAX_LEN)\n",
        "val_ds = ensure_consistent_lengths(val_ds, MAX_LEN)\n",
        "\n",
        "# Verify all sequences are the same length\n",
        "print(\"🔍 Verifying sequence lengths...\")\n",
        "train_lengths = [len(seq) for seq in train_ds[\"input_ids\"]]\n",
        "val_lengths = [len(seq) for seq in val_ds[\"input_ids\"]]\n",
        "\n",
        "print(f\"Training sequence lengths - Min: {min(train_lengths)}, Max: {max(train_lengths)}\")\n",
        "print(f\"Validation sequence lengths - Min: {min(val_lengths)}, Max: {max(val_lengths)}\")\n",
        "\n",
        "if len(set(train_lengths)) == 1 and len(set(val_lengths)) == 1:\n",
        "    print(\"✅ All sequences have consistent length\")\n",
        "else:\n",
        "    print(\"⚠️  Inconsistent sequence lengths detected - this will cause training errors\")\n",
        "\n",
        "# Set format for PyTorch\n",
        "train_ds.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
        "val_ds.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
        "\n",
        "print(f\"Sample input_ids shape: {train_ds[0]['input_ids'].shape}\")\n",
        "print(f\"Sample attention_mask shape: {train_ds[0]['attention_mask'].shape}\")\n",
        "print(f\"Sample labels shape: {train_ds[0]['labels'].shape}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 41,
      "id": "55a3b346",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "55a3b346",
        "outputId": "34ba52e3-3a22-4c60-e18b-18592c0b1d80"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loading tokenizer...\n",
            "✅ Tokenizer loaded successfully\n",
            "Loading base model (4-bit optimized for Colab Pro)...\n",
            "⚠️  Error with 4-bit quantization: Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`\n",
            "🔄 Trying without quantization...\n",
            "✅ Model loaded without quantization\n",
            "Model device: cuda:0\n",
            "Model dtype: torch.float16\n"
          ]
        }
      ],
      "source": [
        "# Model setup optimized for Colab Pro\n",
        "# Using a more compatible model that works well with current transformers\n",
        "base_model = \"microsoft/DialoGPT-medium\"  # More stable and widely supported\n",
        "\n",
        "# 4-bit quantization config optimized for Colab Pro\n",
        "bnb_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    bnb_4bit_compute_dtype=torch.float16,\n",
        ")\n",
        "\n",
        "print(\"Loading tokenizer...\")\n",
        "try:\n",
        "    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True, trust_remote_code=True)\n",
        "    print(\"✅ Tokenizer loaded successfully\")\n",
        "except Exception as e:\n",
        "    print(f\"⚠️  Error loading tokenizer: {e}\")\n",
        "    print(\"🔄 Trying alternative approach...\")\n",
        "    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=False, trust_remote_code=False)\n",
        "\n",
        "tokenizer.pad_token = tokenizer.eos_token\n",
        "tokenizer.padding_side = \"right\"\n",
        "\n",
        "print(\"Loading base model (4-bit optimized for Colab Pro)...\")\n",
        "try:\n",
        "    model = AutoModelForCausalLM.from_pretrained(\n",
        "        base_model,\n",
        "        quantization_config=bnb_config,\n",
        "        device_map=\"auto\",\n",
        "        low_cpu_mem_usage=True,\n",
        "        trust_remote_code=True,\n",
        "        torch_dtype=torch.float16,\n",
        "    )\n",
        "    print(\"✅ Model loaded successfully\")\n",
        "except Exception as e:\n",
        "    print(f\"⚠️  Error with 4-bit quantization: {e}\")\n",
        "    print(\"🔄 Trying without quantization...\")\n",
        "    model = AutoModelForCausalLM.from_pretrained(\n",
        "        base_model,\n",
        "        device_map=\"auto\",\n",
        "        low_cpu_mem_usage=True,\n",
        "        torch_dtype=torch.float16,\n",
        "    )\n",
        "    print(\"✅ Model loaded without quantization\")\n",
        "\n",
        "print(f\"Model device: {next(model.parameters()).device}\")\n",
        "print(f\"Model dtype: {next(model.parameters()).dtype}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "QCcujyKNyTud",
      "metadata": {
        "id": "QCcujyKNyTud"
      },
      "outputs": [],
      "source": [
        "from peft import prepare_model_for_kbit_training\n",
        "\n",
        "# disable cache for gradient checkpointing\n",
        "model.config.use_cache = False\n",
        "\n",
        "# enable gradient checkpointing\n",
        "model.gradient_checkpointing_enable()\n",
        "\n",
        "# IMPORTANT: prepare for k-bit training (sets up norms, cast, etc.)\n",
        "model = prepare_model_for_kbit_training(model)\n",
        "\n",
        "# ensure inputs carry grads for checkpointing\n",
        "if hasattr(model, \"enable_input_require_grads\"):\n",
        "    model.enable_input_require_grads()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 48,
      "id": "a3o-5dxDr5MH",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "a3o-5dxDr5MH",
        "outputId": "778c6094-3fbc-468f-dbec-7d1798791f04"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "trainable params: 6,291,456 || all params: 361,114,624 || trainable%: 1.7422\n",
            "✅ LoRA configuration applied for GPT-2/DialoGPT modules\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/peft/mapping_func.py:73: UserWarning: You are trying to modify a model with PEFT for a second time. If you want to reload the model with a different config, make sure to call `.unload()` before.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.12/dist-packages/peft/tuners/tuners_utils.py:196: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!\n",
            "  warnings.warn(\n"
          ]
        }
      ],
      "source": [
        "# LoRA configuration compatible with GPT-2/DialoGPT modules\n",
        "from peft import LoraConfig, get_peft_model, TaskType\n",
        "\n",
        "# For GPT-2/DialoGPT, target modules typically are c_attn (QKV), c_fc and c_proj (MLP)\n",
        "lora_config = LoraConfig(\n",
        "    r=16,\n",
        "    lora_alpha=32,\n",
        "    lora_dropout=0.05,\n",
        "    bias=\"none\",\n",
        "    task_type=TaskType.CAUSAL_LM,\n",
        "    target_modules=[\"c_attn\", \"c_fc\", \"c_proj\"],\n",
        ")\n",
        "\n",
        "# Apply LoRA to model\n",
        "model = get_peft_model(model, lora_config)\n",
        "model.print_trainable_parameters()\n",
        "print(\"✅ LoRA configuration applied for GPT-2/DialoGPT modules\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 49,
      "id": "ac85c418",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ac85c418",
        "outputId": "196789bc-c66f-4fbd-eb66-65a50b3cf995"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "✅ Training arguments configured !!\n"
          ]
        }
      ],
      "source": [
        "# Training arguments\n",
        "training_args = TrainingArguments(\n",
        "    output_dir=\"./outputs\",\n",
        "    per_device_train_batch_size=2,\n",
        "    per_device_eval_batch_size=2,\n",
        "    gradient_accumulation_steps=8,\n",
        "    num_train_epochs=3,\n",
        "    learning_rate=2e-4,\n",
        "    bf16=True,\n",
        "    logging_steps=10,\n",
        "    eval_strategy=\"steps\",\n",
        "    eval_steps=50,\n",
        "    save_steps=100,\n",
        "    save_total_limit=3,\n",
        "    lr_scheduler_type=\"cosine\",\n",
        "    warmup_ratio=0.03,\n",
        "    gradient_checkpointing=True,\n",
        "    dataloader_pin_memory=False,\n",
        "    remove_unused_columns=False,\n",
        "    report_to=[\"wandb\"] if os.environ.get('WANDB_API_KEY') else [],\n",
        "    seed=42,\n",
        "    # Colab Pro optimizations\n",
        "    dataloader_num_workers=2,\n",
        "    save_safetensors=True,\n",
        "    load_best_model_at_end=True,\n",
        "    metric_for_best_model=\"eval_loss\",\n",
        "    greater_is_better=False,\n",
        ")\n",
        "\n",
        "print(\"✅ Training arguments configured !!\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 50,
      "id": "b7452949",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "b7452949",
        "outputId": "c84d0df8-efef-4423-ab1f-bc343766b386"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "✅ Custom data collator for pre-padded sequences configured\n"
          ]
        }
      ],
      "source": [
        "# Custom data collator for pre-padded sequences\n",
        "# Since we already padded during tokenization, we just need to stack tensors\n",
        "def custom_collate_fn(batch):\n",
        "    \"\"\"Custom collate function for pre-padded sequences\"\"\"\n",
        "    # Extract the fields we need\n",
        "    input_ids = torch.stack([torch.tensor(item[\"input_ids\"]) for item in batch])\n",
        "    attention_mask = torch.stack([torch.tensor(item[\"attention_mask\"]) for item in batch])\n",
        "    labels = torch.stack([torch.tensor(item[\"labels\"]) for item in batch])\n",
        "\n",
        "    return {\n",
        "        \"input_ids\": input_ids,\n",
        "        \"attention_mask\": attention_mask,\n",
        "        \"labels\": labels\n",
        "    }\n",
        "\n",
        "# Use our custom collator\n",
        "data_collator = custom_collate_fn\n",
        "\n",
        "print(\"✅ Custom data collator for pre-padded sequences configured\")\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 51,
      "id": "qFVD1QGmxgv4",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qFVD1QGmxgv4",
        "outputId": "591ac10d-f8e2-461a-a629-9617bb6a120a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "✅ Data collator configured\n"
          ]
        }
      ],
      "source": [
        "# Data collator for language modeling\n",
        "data_collator = DataCollatorForLanguageModeling(\n",
        "    tokenizer=tokenizer,\n",
        "    mlm=False,  # We're doing causal LM, not masked LM\n",
        "    pad_to_multiple_of=8,  # Optimize for GPU\n",
        ")\n",
        "\n",
        "print(\"✅ Data collator configured\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 52,
      "id": "fc57cf22",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fc57cf22",
        "outputId": "b83c8108-13ac-4114-a4a9-7beb26963ee5"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3978596696.py:2: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
            "  trainer = Trainer(\n",
            "The model is already on multiple devices. Skipping the move to device specified in `args`.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "✅ Trainer configured\n",
            "Training examples: 150\n",
            "Validation examples: 50\n",
            "Total training steps: 27\n"
          ]
        }
      ],
      "source": [
        "# Create trainer\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=train_ds,\n",
        "    eval_dataset=val_ds,\n",
        "    data_collator=data_collator,\n",
        "    tokenizer=tokenizer,\n",
        ")\n",
        "\n",
        "print(\"✅ Trainer configured\")\n",
        "print(f\"Training examples: {len(train_ds)}\")\n",
        "print(f\"Validation examples: {len(val_ds)}\")\n",
        "print(f\"Total training steps: {len(train_ds) // training_args.per_device_train_batch_size // training_args.gradient_accumulation_steps * training_args.num_train_epochs}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 53,
      "id": "547502bd",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 197
        },
        "id": "547502bd",
        "outputId": "11530f9b-1a40-4353-dfa2-6caa4d4ff22e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "🚀 Starting training...\n",
            "Training on: NVIDIA A100-SXM4-40GB\n",
            "Batch size: 2\n",
            "Gradient accumulation: 8\n",
            "Effective batch size: 16\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='30' max='30' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [30/30 00:40, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "✅ Training completed!\n",
            "Model saved to: ./outputs\n"
          ]
        }
      ],
      "source": [
        "# Start training\n",
        "print(\"🚀 Starting training...\")\n",
        "print(f\"Training on: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
        "print(f\"Batch size: {training_args.per_device_train_batch_size}\")\n",
        "print(f\"Gradient accumulation: {training_args.gradient_accumulation_steps}\")\n",
        "print(f\"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}\")\n",
        "\n",
        "# Train the model\n",
        "trainer.train()\n",
        "\n",
        "print(\"✅ Training completed!\")\n",
        "print(f\"Model saved to: {training_args.output_dir}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 54,
      "id": "a4df3b21",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "a4df3b21",
        "outputId": "3cb4ac09-9b61-4ab8-bc11-cf480cb764be"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "✅ Model and tokenizer saved\n",
            "Saved to: ./outputs\n",
            "Mounted at /content/drive\n",
            "✅ Model also saved to Google Drive: /content/drive/MyDrive/Colab Notebooks/finetuned_model_20251028_123003\n"
          ]
        }
      ],
      "source": [
        "# Save the final model\n",
        "trainer.save_model()\n",
        "tokenizer.save_pretrained(training_args.output_dir)\n",
        "\n",
        "print(\"✅ Model and tokenizer saved\")\n",
        "print(f\"Saved to: {training_args.output_dir}\")\n",
        "\n",
        "# Save to Google Drive (optional)\n",
        "try:\n",
        "    from google.colab import drive\n",
        "    drive.mount('/content/drive')\n",
        "\n",
        "    # Copy to Drive\n",
        "    import shutil\n",
        "    drive_path = f\"/content/drive/MyDrive/Colab Notebooks/finetuned_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}\"\n",
        "    shutil.copytree(training_args.output_dir, drive_path)\n",
        "    print(f\"✅ Model also saved to Google Drive: {drive_path}\")\n",
        "except:\n",
        "    print(\"⚠️  Google Drive not available, model saved locally only\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 55,
      "id": "e2507760",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 228
        },
        "id": "e2507760",
        "outputId": "03dda8b6-711c-4bee-b4ed-b00aedde5ab4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "📊 Evaluating model...\n",
            "⚠️  Best checkpoint not found, using final model\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='25' max='25' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [25/25 00:01]\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "📈 Evaluation Results:\n",
            "   eval_loss: 5.8997\n",
            "   eval_runtime: 1.5263\n",
            "   eval_samples_per_second: 32.7600\n",
            "   eval_steps_per_second: 16.3800\n",
            "   epoch: 3.0000\n",
            "\n",
            "✅ Evaluation completed!\n"
          ]
        }
      ],
      "source": [
        "# Evaluate the model\n",
        "print(\"📊 Evaluating model...\")\n",
        "\n",
        "# Load the best model\n",
        "best_model_path = f\"{training_args.output_dir}/checkpoint-best\"\n",
        "if os.path.exists(best_model_path):\n",
        "    model = PeftModel.from_pretrained(model, best_model_path)\n",
        "    print(\"✅ Loaded best checkpoint\")\n",
        "else:\n",
        "    print(\"⚠️  Best checkpoint not found, using final model\")\n",
        "\n",
        "# Run evaluation\n",
        "eval_results = trainer.evaluate()\n",
        "print(f\"\\n📈 Evaluation Results:\")\n",
        "for key, value in eval_results.items():\n",
        "    print(f\"   {key}: {value:.4f}\")\n",
        "\n",
        "print(\"\\n✅ Evaluation completed!\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 56,
      "id": "c80bebe1",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "c80bebe1",
        "outputId": "75b9cf98-cae9-48c0-de41-eb245f574e0c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "🧪 Testing inference...\n",
            "\n",
            "--- Test 1 ---\n",
            "Item: MyCableMart 3.5mm Plug/Jack, 4 Conductor TRRS, Self Solder, Male\n",
            "Actual Price: $25.00\n",
            "Model Response: <|system|>\n",
            "You are a retail price estimator. Predict the most likely new retail price in USD.\n",
            "<|user|>\n",
            "MyCableMart 3.5mm Plug/Jack, 4 Conductor TRRS, Self Solder, Male\n",
            "Connects stereo audio & microphone devices requiring 4 conductors (left and right audio and microphone plus ground). This connector MAY also be suitable for left/right audio 1 video (composite) and ground. Great for making your own 3.5mm 4 conductor Cables or for repairing existing cables. Wire terminals are attached using solder (not included).Features 3.5mm 4 conductor (3 band) plug 3.5mm 4 conductor (3 band) plug Nickel Plated Nickel Plated Strain relief Strain relief Outer Dimensions (at PVC outer molding) Outer Dimensions (at PVC outer molding) Outer Dimensions (with PVC outer molding\n",
            "<|assistant|>\n",
            "input.5, 3.00,,5,2,2,2,2,2\n",
            "\n",
            "--- Test 2 ---\n",
            "Item: OtterBox + Pop Symmetry Series Case for iPhone 11 Pro (ONLY) - Retail Packaging - White Marble\n",
            "Actual Price: $20.00\n",
            "Model Response: <|system|>\n",
            "You are a retail price estimator. Predict the most likely new retail price in USD.\n",
            "<|user|>\n",
            "OtterBox + Pop Symmetry Series Case for iPhone 11 Pro (ONLY) - Retail Packaging - White Marble\n",
            "OtterBox + Pop Symmetry Series Case for iPhone 11 Pro (ONLY) - Retail Packaging - White Marble Compatible with iPhone 11 Pro Thin one-piece case with durable protection against drops, bumps and fumbles that is also compatible with Qi wireless charging PopSockets PopGrip is integrated into case to help with holding, texting, snapping better pictures and hand-free viewing PopTop designs are easy to switch out — just close flat, press down and turn to swap the PopTop. Includes OtterBox limited lifetime warranty (see website for details) and 100% authentic Dimensions 7.8 x 4.29 x 1.06 inches, Weight 3\n",
            "<|assistant|>\n",
            "Type.html,.html.html,, Material, width, material, material, material,\n",
            "\n",
            "--- Test 3 ---\n",
            "Item: Dell XPS Desktop ( Intel Core i7 4790 (3.6 GHz), 8GB, 1TB HDD,Windows 10 Home Black\n",
            "Actual Price: $500.00\n",
            "Model Response: <|system|>\n",
            "You are a retail price estimator. Predict the most likely new retail price in USD.\n",
            "<|user|>\n",
            "Dell XPS Desktop ( Intel Core i7 4790 (3.6 GHz), 8GB, 1TB HDD,Windows 10 Home Black\n",
            "Product description Bring your multimedia to life with Dell XPS desktop PCs offering powerful processors, superb graphics performance and lots of storage space. Amazon.com Processor 4th Generation Intel Core processor (8M Cache, up to 4.00 GHz) OS Windows 7 Professional, English Graphics Card NVIDIA GeForce GTX 750Ti 2GB DDR5 Memory 32GB Dual Channel DDR3 - 4 DIMMs Hard Drive 1TB 7200 RPM SATA Hard Drive 6.0 Gb/s + 256GB SSD Processor 3.6 GHz RAM 8 GB DDR5, Memory Speed 1600 MHz,\n",
            "<|assistant|>\n",
            "USB HDD,RAM, HDD.8GB2,USB HDD,USB HDD,USB HDD,\n",
            "\n",
            "✅ Inference testing completed!\n"
          ]
        }
      ],
      "source": [
        "# Test inference on sample data\n",
        "print(\"🧪 Testing inference...\")\n",
        "\n",
        "def test_inference(model, tokenizer, test_item):\n",
        "    \"\"\"Test inference on a single item\"\"\"\n",
        "    prompt = f\"<|system|>\\nYou are a retail price estimator. Predict the most likely new retail price in USD.\\n<|user|>\\n{test_item.title}\\n{test_item.description}\\n<|assistant|>\\n\"\n",
        "\n",
        "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        outputs = model.generate(\n",
        "            **inputs,\n",
        "            max_new_tokens=20,\n",
        "            temperature=0.7,\n",
        "            do_sample=True,\n",
        "            pad_token_id=tokenizer.eos_token_id\n",
        "        )\n",
        "\n",
        "    response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "    return response\n",
        "\n",
        "# Test on a few examples\n",
        "for i, item in enumerate(test[:3]):\n",
        "    print(f\"\\n--- Test {i+1} ---\")\n",
        "    print(f\"Item: {item.title}\")\n",
        "    print(f\"Actual Price: ${item.price:.2f}\")\n",
        "\n",
        "    try:\n",
        "        response = test_inference(model, tokenizer, item)\n",
        "        print(f\"Model Response: {response}\")\n",
        "    except Exception as e:\n",
        "        print(f\"Error: {e}\")\n",
        "\n",
        "print(\"\\n✅ Inference testing completed!\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4e716982",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Fixed evaluation with price range constraints and better post-processing\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import re\n",
        "import torch\n",
        "\n",
        "def extract_price_safe(text: str) -> float:\n",
        "    \"\"\"Extract price with safety constraints\"\"\"\n",
        "    if not text:\n",
        "        return 0.0\n",
        "    \n",
        "    # Clean the text\n",
        "    text = str(text).replace(\"$\", \"\").replace(\",\", \"\").strip()\n",
        "    \n",
        "    # Look for price patterns\n",
        "    patterns = [\n",
        "        r'\\$?(\\d+\\.?\\d*)\\s*(?:dollars?|USD|usd)?',  # $123.45 or 123.45 dollars\n",
        "        r'(\\d+\\.?\\d*)',  # Just numbers\n",
        "    ]\n",
        "    \n",
        "    for pattern in patterns:\n",
        "        matches = re.findall(pattern, text, re.IGNORECASE)\n",
        "        if matches:\n",
        "            try:\n",
        "                price = float(matches[0])\n",
        "                # Apply reasonable price constraints\n",
        "                if 0.01 <= price <= 100000:  # Between 1 cent and $100k\n",
        "                    return price\n",
        "            except ValueError:\n",
        "                continue\n",
        "    \n",
        "    return 0.0\n",
        "\n",
        "def build_pricing_prompt_fixed(item) -> str:\n",
        "    \"\"\"Build prompt with explicit price range guidance\"\"\"\n",
        "    return (\n",
        "        \"<|system|>\\n\"\n",
        "        \"You are a retail price estimator. Predict the most likely new retail price in USD. \"\n",
        "        \"Typical prices range from $1 to $10,000. Be realistic and conservative.\\n\"\n",
        "        \"<|user|>\\n\"\n",
        "        f\"Product: {item.title}\\n\"\n",
        "        f\"Description: {item.description}\\n\"\n",
        "        f\"Category: {getattr(item, 'category', 'Unknown')}\\n\"\n",
        "        \"What is the retail price?\\n\"\n",
        "        \"<|assistant|>\\n\"\n",
        "        \"The retail price is $\"\n",
        "    )\n",
        "\n",
        "@torch.no_grad()\n",
        "def predict_price_fixed(model, tokenizer, item, max_new_tokens=15) -> float:\n",
        "    \"\"\"Predict price with better constraints\"\"\"\n",
        "    prompt = build_pricing_prompt_fixed(item)\n",
        "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
        "    \n",
        "    # Generate with more conservative settings\n",
        "    outputs = model.generate(\n",
        "        **inputs,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        temperature=0.3,  # Lower temperature for more conservative predictions\n",
        "        do_sample=True,\n",
        "        pad_token_id=tokenizer.eos_token_id,\n",
        "        repetition_penalty=1.1,  # Reduce repetition\n",
        "        no_repeat_ngram_size=2,\n",
        "    )\n",
        "    \n",
        "    # Decode only the new tokens\n",
        "    prompt_length = len(tokenizer.decode(inputs[\"input_ids\"][0], skip_special_tokens=True))\n",
        "    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "    new_text = full_response[prompt_length:]\n",
        "    \n",
        "    # Extract price with constraints\n",
        "    price = extract_price_safe(new_text)\n",
        "    \n",
        "    # Additional safety: if price is still unreasonable, use a fallback\n",
        "    if price > 50000:  # If over $50k, it's probably wrong\n",
        "        # Try to extract a more reasonable number\n",
        "        numbers = re.findall(r'\\d+\\.?\\d*', new_text)\n",
        "        if numbers:\n",
        "            try:\n",
        "                # Take the first reasonable number\n",
        "                for num in numbers:\n",
        "                    candidate = float(num)\n",
        "                    if 1 <= candidate <= 10000:\n",
        "                        return candidate\n",
        "            except ValueError:\n",
        "                pass\n",
        "        return 0.0\n",
        "    \n",
        "    return price\n",
        "\n",
        "def evaluate_model_fixed(model, tokenizer, test_items, limit=None, title=\"Fixed Fine-tuned Model\"):\n",
        "    \"\"\"Evaluate with fixed price extraction\"\"\"\n",
        "    if not test_items:\n",
        "        print(\"⚠️ No test items available.\")\n",
        "        return {\"mae\": None, \"rmse\": None, \"mape\": None}\n",
        "    \n",
        "    items = test_items[:limit] if limit else test_items\n",
        "    print(f\"🔍 Evaluating on {len(items)} items...\")\n",
        "    \n",
        "    y_true, y_pred = [], []\n",
        "    errors = []\n",
        "    \n",
        "    for i, item in enumerate(items):\n",
        "        try:\n",
        "            pred = predict_price_fixed(model, tokenizer, item)\n",
        "            true_price = float(getattr(item, \"price\", 0.0))\n",
        "            \n",
        "            y_true.append(true_price)\n",
        "            y_pred.append(pred)\n",
        "            \n",
        "            # Track individual errors for debugging\n",
        "            error = abs(pred - true_price)\n",
        "            errors.append({\n",
        "                'item': i,\n",
        "                'title': getattr(item, 'title', 'Unknown')[:50],\n",
        "                'true': true_price,\n",
        "                'pred': pred,\n",
        "                'error': error\n",
        "            })\n",
        "            \n",
        "        except Exception as e:\n",
        "            print(f\"Error on item {i}: {e}\")\n",
        "            y_true.append(0.0)\n",
        "            y_pred.append(0.0)\n",
        "    \n",
        "    y_true = np.array(y_true, dtype=float)\n",
        "    y_pred = np.array(y_pred, dtype=float)\n",
        "    \n",
        "    # Calculate metrics\n",
        "    mae = float(np.mean(np.abs(y_pred - y_true)))\n",
        "    rmse = float(np.sqrt(np.mean((y_pred - y_true) ** 2)))\n",
        "    \n",
        "    # MAPE (avoid division by zero)\n",
        "    mape = float(np.mean(np.abs((y_true - y_pred) / np.maximum(y_true, 1.0)))) * 100\n",
        "    \n",
        "    # Hits within 15% tolerance\n",
        "    tolerance = 0.15\n",
        "    hits = float(np.mean(np.abs(y_pred - y_true) <= (tolerance * np.maximum(y_true, 1.0)))) * 100\n",
        "    \n",
        "    # Create scatter plot\n",
        "    plt.figure(figsize=(8, 6))\n",
        "    plt.scatter(y_true, y_pred, alpha=0.7, s=30, c='blue')\n",
        "    \n",
        "    # Add diagonal line\n",
        "    max_val = max(y_true.max() if y_true.size else 0, y_pred.max() if y_pred.size else 0, 1)\n",
        "    plt.plot([0, max_val], [0, max_val], 'r--', alpha=0.8, label='Perfect Prediction')\n",
        "    \n",
        "    plt.xlabel('True Price ($)')\n",
        "    plt.ylabel('Predicted Price ($)')\n",
        "    plt.title(f'{title}\\nMAE=${mae:.2f} RMSE=${rmse:.2f} MAPE={mape:.1f}% Hits={hits:.1f}%')\n",
        "    plt.legend()\n",
        "    plt.grid(True, alpha=0.3)\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "    \n",
        "    # Show worst predictions\n",
        "    errors.sort(key=lambda x: x['error'], reverse=True)\n",
        "    print(f\"\\n🔍 Top 5 Worst Predictions:\")\n",
        "    for i, err in enumerate(errors[:5]):\n",
        "        print(f\"  {i+1}. {err['title']}...\")\n",
        "        print(f\"     True: ${err['true']:.2f}, Pred: ${err['pred']:.2f}, Error: ${err['error']:.2f}\")\n",
        "    \n",
        "    return {\n",
        "        \"mae\": mae,\n",
        "        \"rmse\": rmse, \n",
        "        \"mape\": mape,\n",
        "        \"hits_pct\": hits,\n",
        "        \"y_true\": y_true,\n",
        "        \"y_pred\": y_pred,\n",
        "        \"errors\": errors\n",
        "    }\n",
        "\n",
        "# Test the fixed evaluation\n",
        "print(\"🧪 Testing fixed price prediction...\")\n",
        "results = evaluate_model_fixed(model, tokenizer, test, limit=20, title=\"Fixed Fine-tuned Model\")\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "0614c35b3690494ca3b8f9ab71d71a08": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1c3eb3793b6e4291b4fa57ce8419ef1f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "227eea7074544adbb2c34b9dde340fa5": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "245570c62c3844728d7125a706fbbc9b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_8d95da3803e542f8b855175013d497ba",
              "IPY_MODEL_34a89db126a64690bc2f6c8656ba2210",
              "IPY_MODEL_4c47ce21b5a14328aa22403782e4da9b"
            ],
            "layout": "IPY_MODEL_7dff366d9e71427dbae40b1dce7a9bfa"
          }
        },
        "24b2b5f5d92049a79014b8278e97451b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_cceae79c145d4b73a64e80ad3fc8866c",
            "placeholder": "​",
            "style": "IPY_MODEL_56bc56071ff04223935dc2d98d2703ab",
            "value": "Map: 100%"
          }
        },
        "34a89db126a64690bc2f6c8656ba2210": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4cdff5bdf7574795802e821aa42f3c4e",
            "max": 150,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_754aa440f45c4a878d99572368d659c8",
            "value": 150
          }
        },
        "42315e83fbac49c2bc7f2faf1abcc22e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "4c47ce21b5a14328aa22403782e4da9b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8b5f0c156a9641cfa5413668a0b97b9c",
            "placeholder": "​",
            "style": "IPY_MODEL_aff498bd632f4036958f59cfc6587ea3",
            "value": " 150/150 [00:00&lt;00:00, 1904.80 examples/s]"
          }
        },
        "4cdff5bdf7574795802e821aa42f3c4e": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "56bc56071ff04223935dc2d98d2703ab": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "67cacc87afe14250baaa073289fb4a8f": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "754aa440f45c4a878d99572368d659c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "7dff366d9e71427dbae40b1dce7a9bfa": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8b5f0c156a9641cfa5413668a0b97b9c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8bd9aebb2cc5420094b2b441a5183523": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8d95da3803e542f8b855175013d497ba": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0614c35b3690494ca3b8f9ab71d71a08",
            "placeholder": "​",
            "style": "IPY_MODEL_42315e83fbac49c2bc7f2faf1abcc22e",
            "value": "Map: 100%"
          }
        },
        "9ce1659c776140bcaf3c16eae6f70967": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "a6001d34e58a47cab0d8bff2451afb6e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_67cacc87afe14250baaa073289fb4a8f",
            "max": 50,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_227eea7074544adbb2c34b9dde340fa5",
            "value": 50
          }
        },
        "aff498bd632f4036958f59cfc6587ea3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "b42e8d8b61d7431a814a03c5e07a1166": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8bd9aebb2cc5420094b2b441a5183523",
            "placeholder": "​",
            "style": "IPY_MODEL_1c3eb3793b6e4291b4fa57ce8419ef1f",
            "value": " 50/50 [00:00&lt;00:00, 1509.88 examples/s]"
          }
        },
        "cceae79c145d4b73a64e80ad3fc8866c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d6825cc926a24f2482ce72c15242081e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_24b2b5f5d92049a79014b8278e97451b",
              "IPY_MODEL_a6001d34e58a47cab0d8bff2451afb6e",
              "IPY_MODEL_b42e8d8b61d7431a814a03c5e07a1166"
            ],
            "layout": "IPY_MODEL_9ce1659c776140bcaf3c16eae6f70967"
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
